{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02165.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02165.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2u6PP0Pru8tB",
        "colab_type": "text"
      },
      "source": [
        "## 5.8 网络中的网络 ( NiN )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DybwHrEixkPv",
        "colab_type": "text"
      },
      "source": [
        "### 5.8.1 NiN 块"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dls6nqZGxu8p",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torch import nn, optim\n",
        "import d2l \n",
        "import time \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \n",
        "\n",
        "\n",
        "def nin_block(in_channels, out_channels, kernel_size, stride, padding):\n",
        "    blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), \n",
        "               nn.ReLU(), \n",
        "               nn.Conv2d(out_channels, out_channels, kernel_size=1), \n",
        "               nn.ReLU(), \n",
        "               nn.Conv2d(out_channels, out_channels, kernel_size=1), \n",
        "               nn.ReLU())\n",
        "    return blk "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X02YzCveyvgS",
        "colab_type": "text"
      },
      "source": [
        "### 5.8.2 NiN 模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VpgOneDEyzDx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "net = nn.Sequential(\n",
        "    nin_block(1, 96, kernel_size=11, stride=4, padding=0), \n",
        "    nn.MaxPool2d(kernel_size=3, stride=2), \n",
        "    nin_block(96, 256, kernel_size=5, stride=1, padding=2),\n",
        "    nn.MaxPool2d(kernel_size=3, stride=2), \n",
        "    nin_block(256, 384, kernel_size=3, stride=1, padding=1), \n",
        "    nn.MaxPool2d(kernel_size=3, stride=2), \n",
        "    nn.Dropout(0.5), \n",
        "    nin_block(384, 10, kernel_size=3, stride=1, padding=1), \n",
        "    nn.AvgPool2d(kernel_size=5),\n",
        "    d2l.FlattenLayer()\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dwna_Ihn0Z52",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        },
        "outputId": "a5dbc811-0120-4e5e-adda-a8d6a4bead30"
      },
      "source": [
        "X = torch.rand(1, 1, 224, 224)\n",
        "for name, blk in net.named_children():\n",
        "    X = blk(X)\n",
        "    print(name, 'output shape: ', X.shape)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0 output shape:  torch.Size([1, 96, 54, 54])\n",
            "1 output shape:  torch.Size([1, 96, 26, 26])\n",
            "2 output shape:  torch.Size([1, 256, 26, 26])\n",
            "3 output shape:  torch.Size([1, 256, 12, 12])\n",
            "4 output shape:  torch.Size([1, 384, 12, 12])\n",
            "5 output shape:  torch.Size([1, 384, 5, 5])\n",
            "6 output shape:  torch.Size([1, 384, 5, 5])\n",
            "7 output shape:  torch.Size([1, 10, 5, 5])\n",
            "8 output shape:  torch.Size([1, 10, 1, 1])\n",
            "9 output shape:  torch.Size([1, 10])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mh0_gdht0sqa",
        "colab_type": "text"
      },
      "source": [
        "### 5.8.3 获取数据和训练模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "za629DBd2TzS",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "8f752ecd-0769-4933-bb25-0ac4d0feeabf"
      },
      "source": [
        "batch_size = 128 \n",
        "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
        "\n",
        "lr, num_epochs = 0.002, 5 \n",
        "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
        "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on  cuda\n",
            "epoch 1, loss 2.3037, train acc 0.099, test acc 0.100, time 142.3 sec\n",
            "epoch 2, loss 2.3026, train acc 0.100, test acc 0.100, time 142.2 sec\n",
            "epoch 3, loss 2.3026, train acc 0.100, test acc 0.100, time 142.1 sec\n",
            "epoch 4, loss 2.3026, train acc 0.100, test acc 0.100, time 141.5 sec\n",
            "epoch 5, loss 2.3026, train acc 0.100, test acc 0.100, time 142.6 sec\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}